## The Hadamard decoder

from PyM import *


def hadamard_decoder(y, H):
    """Decode the received binary word y with respect to the Hadamard matrix H.

    y is turned into its +/-1 form with unbin_ (0 -> -1) and correlated, via a
    dot product, with every row h of H in order. The first row whose
    correlation exceeds len(H)//2 in absolute value decides the answer:
    bin_(h) when the correlation is positive, bin_(-h) when it is negative.

    Returns:
        The decoded binary vector, or the string
        "hadamard_decoder error: length mismatch" if len(y) != len(H), or
        "hadamard_decoder error: undecodable" if no row passes the threshold.
    """
    size = len(H)
    if len(y) != size:
        return "hadamard_decoder error: length mismatch"
    threshold = size // 2
    signed = unbin_(y)
    for row in H:
        # For 0/1 y and a +/-1 row (assumed), corr = n - 2*d, where d is the
        # number of positions in which y differs from bin_(row). So
        # |corr| > n/2 means y is within distance < n/4 of bin_(row)
        # (corr > 0) or of bin_(-row) (corr < 0).
        corr = dot(signed, row)
        if abs(corr) <= threshold:
            continue
        return bin_(row) if corr > 0 else bin_(-row)
    return "hadamard_decoder error: undecodable"
    
    

## Auxiliary functions

def flip(x, j):
    """Return a copy of x with the entry (or entries) at position(s) j flipped.

    Each selected entry e is replaced by 1 - e, which toggles 0 <-> 1 when the
    entries of x are binary (assumed; not checked here). x itself is never
    modified; the result is a new vector over Z_.

    Parameters:
        x: the vector to copy.
        j: a single index (int) or a list/tuple of indices, each satisfying
           0 <= index < len(x). An index repeated in the collection is flipped
           once per occurrence, so an even number of repeats cancels out.

    Returns:
        The flipped copy, or (following this file's error-string convention)
        "flip error: index out of range" if some index is outside
        [0, len(x)), or "flip error: wrong data" if j is neither an int nor a
        list/tuple of indices.
    """
    n = len(x)
    # Normalise j to a sequence of positions. In the previous version the
    # "wrong data" error was unreachable, so any other type of j silently
    # returned an unflipped copy.
    if isinstance(j, int):
        positions = [j]
    elif isinstance(j, (list, tuple)):
        positions = list(j)
    else:
        return "flip error: wrong data"
    # Validate every index before building the copy, so an error never
    # leaves half-done work behind.
    if any(not (0 <= k < n) for k in positions):
        return "flip error: index out of range"
    y = vector(Z_, n)
    for i in range(n):
        y[i] = x[i]
    for k in positions:
        y[k] = 1 - y[k]
    return y


def bin_(x):
    """Return a new vector over Z_ equal to x with every -1 entry set to 0.

    All other entries are copied unchanged. On +/-1 vectors this undoes
    unbin_ (which maps 0 to -1).
    """
    out = vector(Z_, len(x))
    for k in range(len(out)):
        entry = x[k]
        out[k] = 0 if entry == -1 else entry
    return out

def unbin_(x):
    """Return a new vector over Z_ equal to x with every 0 entry set to -1.

    All other entries are copied unchanged. On 0/1 vectors this undoes bin_
    (which maps -1 to 0).
    """
    out = vector(Z_, len(x))
    for k in range(len(out)):
        entry = x[k]
        out[k] = -1 if entry == 0 else entry
    return out

    

## Examples

# Build the Hadamard matrix and its code. hadamard_decoder accepts a row h
# when |unbin_(y).h| > n/2, i.e. when fewer than n/4 positions disagree, so
# t = n/4 - 1 errors are always corrected.
H=hadamard_matrix(Zn(5))  # 12 x 12, corrects t = 12/4 - 1 = 2 errors

# NOTE(review): presumably C lists the codewords bin_(h) and bin_(-h) for the
# rows h of H; confirm against PyM's hadamard_code.
C = hadamard_code(H)


# Example of no error. It is disabled by wrapping it in a string literal;
# move the three lines out of the quotes to run it.
''' 
x = C[4]
show(x) 
show(hadamard_decoder(x,H)) 
'''

# Examples of 1, 2 or 3 errors: keep exactly one "x = ...; y = ..." line
# active. With at most t = 2 flipped positions the decoder should return x
# again; 3 flips exceed t, so correct decoding is not guaranteed.

# x = C[4]; y = flip(x,1)           # 1 error

# x = C[14]; y = flip(x,2)          # 1 error

# x = C[8]; y = flip(x,[1,6])       # 2 errors

# Active example: 2 errors (positions 6 and 11), within t = 2.
x = C[17]; y = flip(x,[6,11])            

# x = C[17]; y = flip(x,[1,6,8])    # 3 errors (beyond t)

# Display the received word, the transmitted codeword, and the decoder's
# output (or its error string) using PyM's show.
show(y)
show(x)
show(hadamard_decoder(y,H)) 
